from tool.args import get_general_args
from tool.util import init_wandb
from train.mlbase import MLBase
from evaluate.evaluator import Evaluator
import torch.nn.functional as F

from data.dl_getter import DATASETS, n_cls, sh, input_range
import pandas as pd
import argparse
import numpy as np
import sys

import torch
import torch.nn as nn
import torchvision as tv
import torchvision.transforms as tr
from tool.util import set_seed, bool_flag
from datetime import datetime
import os
from data.ds import Non_dataset
from data.ds import ood_root
from data.dl_getter import get_transform
from torch.utils.data import DataLoader
import pickle
import matplotlib.pyplot as plt

import abc
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
import math


class Kernel(abc.ABC, nn.Module):
    """Abstract interface that every density-estimation kernel implements."""

    def __init__(self, bandwidth):
        """Creates the kernel and remembers its bandwidth.

        Args:
            bandwidth: The kernel's (band)width.
        """
        super().__init__()
        self.bandwidth = bandwidth

    def _diffs(self, test_Xs, train_Xs):
        """Returns test_Xs[i] - train_Xs[j] for every (i, j) pair.

        A singleton axis is inserted at position 1 of `test_Xs` and at
        position 0 of `train_Xs`, so the subtraction broadcasts into a
        tensor whose first two axes index the test and the train sample.
        """
        return torch.unsqueeze(test_Xs, dim=1) - torch.unsqueeze(train_Xs, dim=0)

    @abc.abstractmethod
    def forward(self, test_Xs, train_Xs):
        """Computes log p(x) for each x in test_Xs given train_Xs."""

    @abc.abstractmethod
    def sample(self, train_Xs):
        """Generates samples from the kernel distribution."""


class GaussianKernel(Kernel):
    """Isotropic Gaussian kernel whose standard deviation is `bandwidth`."""

    def forward(self, test_Xs, train_Xs):
        """Computes log p(x) for each x in test_Xs given train_Xs.

        Every row of `train_Xs` carries one Gaussian with standard deviation
        `self.bandwidth`; the returned value is the log of their average
        density evaluated at each row of `test_Xs`.

        Args:
            test_Xs: 2-D tensor of query points, one row per point.
            train_Xs: 2-D tensor `(n, d)` of kernel centres.

        Returns:
            1-D tensor with one log-density per row of `test_Xs`, located on
            the same device as `test_Xs`.
        """
        n, d = train_Xs.shape
        # Follow the inputs instead of forcing CUDA so the kernel also runs
        # on CPU-only machines.
        device = test_Xs.device
        if test_Xs.is_floating_point():
            dtype = test_Xs.dtype
        else:
            dtype = torch.get_default_dtype()
        h = torch.as_tensor(self.bandwidth, dtype=dtype, device=device)
        n_t = torch.tensor(float(n), dtype=dtype, device=device)
        # log of the normaliser n * (2 * pi)^(d / 2) * h^d.
        log_norm = 0.5 * d * math.log(2 * math.pi) + d * torch.log(h) + torch.log(n_t)
        diffs = self._diffs(test_Xs, train_Xs) / h
        # Squared Euclidean distance taken directly (no sqrt-then-square).
        log_exp = -0.5 * diffs.pow(2).sum(dim=-1)
        return torch.logsumexp(log_exp - log_norm, dim=-1)

    def sample(self, train_Xs):
        """Generates samples by adding Gaussian noise to each row of train_Xs.

        The noise is created with the device and dtype of `train_Xs`, so
        the addition also works when `train_Xs` lives on a GPU.
        """
        noise = torch.randn_like(train_Xs) * self.bandwidth
        return train_Xs + noise

class KDE:
    """Kernel density estimator built on a Gaussian kernel."""
    def __init__(self, train_Xs, std=0.1):
        """Initializes a new KernelDensityEstimator.

        Args:
            train_Xs: The "training" data to use when estimating probabilities.
                Must have exactly two axes.
            std: Bandwidth handed to the GaussianKernel placed on each of the
                train_Xs.
        """
        super().__init__()
        self.kernel = GaussianKernel(bandwidth=std)
        self.train_Xs = train_Xs
        n_axes = len(self.train_Xs.shape)
        assert n_axes == 2, "Input cannot have more than two axes."

    def __call__(self, x):
        """Evaluates the kernel on `x` against the stored training data."""
        kernel = self.kernel
        return kernel(x, self.train_Xs)


def load_data(path):
    """Returns the object stored in the pickle file at `path`.

    NOTE: unpickling executes arbitrary code embedded in the file, so only
    use this on files from a trusted source.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


@torch.no_grad()
def check_acc(model, vl_dl):
    """Prints and returns the top-1 accuracy of `model` on `vl_dl`.

    Args:
        model: Module mapping a batch of inputs to per-class scores; it is
            switched to eval mode.
        vl_dl: Iterable yielding `(inputs, labels)` tensor batches.

    Returns:
        The fraction of correctly classified samples, or `nan` when
        `vl_dl` yields no samples.
    """
    model.eval()
    # Run on whatever device the model lives on; fall back to the previous
    # "CUDA if available" behaviour for models without parameters.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    correct = 0
    total = 0
    for x, y in vl_dl:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(dim=1)
        total += y.size(0)
        correct += (pred == y).sum().item()

    if total == 0:
        # Accuracy is undefined without samples; avoid a ZeroDivisionError.
        print("acc : nan (validation loader is empty)")
        return float("nan")

    acc = correct / total
    print(f"acc : {acc}")
    return acc


def get_dist(data, centeroid, selected_index):
    """Returns the Euclidean distance of each selected row to `centeroid`.

    Args:
        data: Tensor whose first axis enumerates samples.
        centeroid: Tensor that broadcasts against a single row of `data`.
        selected_index: Index (e.g. a boolean mask) applied to the first axis
            of `data` to choose the rows of interest.

    Returns:
        1-D tensor with one distance per selected row. An empty selection
        gives an empty tensor instead of raising.
    """
    selected = data[selected_index]
    diff = selected - centeroid
    if diff.dim() == 1:
        # Scalar samples: the distance is just the absolute difference.
        return diff.abs()
    # One vectorised norm over all non-sample axes replaces a Python loop of
    # per-row torch.dist calls.
    return diff.flatten(start_dim=1).norm(p=2, dim=1)


def dist_to_point(dist1, dist2, dist3):
    """Converts three triangle side lengths into a 2-D point.

    The law of cosines gives the angle enclosed by the sides `dist2` and
    `dist3` (i.e. opposite `dist1`). The returned point lies at distance
    `dist2` from (-0.5, 0) in the direction of that angle, measured from the
    positive x axis.

    NOTE(review): the -0.5 offset presumably assumes `dist3 == 1`, so that
    the other end of that side sits at (0.5, 0) - confirm with callers.

    Args:
        dist1: Length of the side opposite the computed angle.
        dist2: Length of the side from (-0.5, 0) to the returned point.
        dist3: Length of the remaining side.

    Returns:
        Tuple `(x, y)` with the point's coordinates.

    Raises:
        ValueError: If the lengths cannot form a triangle (cosine outside
            [-1, 1] by more than rounding error).
        ZeroDivisionError: If `dist2` or `dist3` is zero.
    """
    cos_angle = (dist2**2 + dist3**2 - dist1**2) / (2 * dist2 * dist3)
    # Degenerate (collinear) triangles can land a hair outside [-1, 1]
    # through floating-point rounding; snap those back so acos does not fail.
    if 1.0 < abs(cos_angle) <= 1.0 + 1e-9:
        cos_angle = math.copysign(1.0, cos_angle)

    angle = math.acos(cos_angle)

    x = dist2 * math.cos(angle) - 0.5
    y = dist2 * math.sin(angle)
    return x, y


def scatter_plot(scatter_x, scatter_y, scatter_x2, scatter_y2, labels):
    """Draws two labelled point clouds, their means and two fixed anchors.

    Args:
        scatter_x, scatter_y: Coordinates of the first cloud; both must
            provide `.mean()`.
        scatter_x2, scatter_y2: Coordinates of the second cloud; both must
            provide `.mean()`.
        labels: Sequence whose first two entries name the clouds in the legend.
    """
    # Fixed anchor points at (-0.5, 0) and (0.5, 0).
    for anchor_x, anchor_color in ((-0.5, 'blue'), (0.5, 'red')):
        plt.scatter(anchor_x, 0, color=anchor_color)

    clouds = (
        (scatter_x, scatter_y, 'blue'),
        (scatter_x2, scatter_y2, 'red'),
    )
    for idx, (xs, ys, mean_color) in enumerate(clouds):
        # Translucent cloud first, then its mean in the matching anchor colour.
        plt.scatter(xs, ys, label=labels[idx], alpha=0.3)
        plt.scatter(xs.mean(), ys.mean(), color=mean_color)

    plt.legend()


def _mean_confidence_per_bin(head, latents, scores, n_bins=20, batch_size=100):
    """Averages the head's max-softmax confidence inside equal-width score bins.

    Args:
        head: Callable mapping a batch of latents to class logits.
        latents: Tensor of latents, indexed along its first axis by sample.
        scores: 1-D tensor holding one score per latent.
        n_bins: Number of equal-width bins spanning
            [scores.min(), scores.max()].
        batch_size: Maximum number of latents passed to `head` at once.

    Returns:
        Pair `(upper_edges, mean_confidence)` of lists with `n_bins`
        entries each. A bin without samples contributes `nan` (this assumes
        `head` accepts an empty batch).
    """
    edges = np.linspace(scores.min().item(), scores.max().item(), n_bins + 1)
    upper_edges, mean_confidence = [], []
    with torch.no_grad():
        for bin_idx, (low, high) in enumerate(zip(edges[:-1], edges[1:])):
            if bin_idx == n_bins - 1:
                # Close the last bin on the right so the maximum is counted.
                mask = (scores >= low) & (scores <= high)
            else:
                mask = (scores >= low) & (scores < high)
            bin_latents = latents[mask]
            # torch.split keeps the final partial batch, so every sample of
            # the bin contributes to the mean.
            if bin_latents.shape[0] > 0:
                chunks = torch.split(bin_latents, batch_size)
            else:
                chunks = (bin_latents,)
            confidences = [
                torch.softmax(head(chunk), dim=1).max(dim=1)[0].cpu()
                for chunk in chunks
            ]
            upper_edges.append(high)
            mean_confidence.append(torch.cat(confidences).mean().item())
    return upper_edges, mean_confidence


def main(eval):
    """Produces the density histogram and the centroid scatter figure.

    Steps:
      1. Print the validation accuracy of `eval.model`.
      2. Fit a Gaussian KDE on the step-0 latents, score the step-0 and the
         step-20 latents with it, and save to `ad_sample/hist.png` a
         histogram of both score sets overlaid with the mean max-softmax
         confidence of `eval.model.head` per score bin.
      3. For every (source class, predicted class) pair among the step-32
         adversarial samples, measure the mean distance of the selected
         latents to both class centroids at PGD steps 0, 1, 2, 4, 8, 16 and
         32, place each step in a 2-D frame with the two centroids at
         (-0.5, 0) and (0.5, 0), and save `ad_sample/scatter.png`.

    All inputs are read from pickle files under `ad_sample/`.

    Args:
        eval: Object exposing `model` (which must provide a `head`
            callable) and `vl_dl`.
    """
    model = eval.model
    check_acc(model, eval.vl_dl)

    # ---- density histogram ------------------------------------------------
    # Step-0 labels; needed further down to select each sample's source class.
    with open('ad_sample/pgd_0.03_0.pkl', 'rb') as f:
        lbls = pickle.load(f)['label']

    x_latent = load_data('ad_sample/pgd_0.03_0_latent.pkl')
    advp_latent = load_data('ad_sample/pgd_0.03_20_latent.pkl')

    kde = KDE(x_latent, std=7)
    # Score two samples at a time to bound the size of the pairwise tensor.
    x_esti = torch.cat([kde(zp) for zp in tqdm(torch.split(x_latent, 2))])
    tt = torch.cat([kde(zp) for zp in tqdm(torch.split(advp_latent, 2))])

    line_x, line_y = _mean_confidence_per_bin(model.head, x_latent, x_esti)
    line_x_adv, line_y_adv = _mean_confidence_per_bin(model.head, advp_latent, tt)

    # Both curves are shifted by the maximum step-0 score, matching the
    # histograms below, so everything shares one x axis.
    reference = x_esti.max().item()
    line_x = [edge - reference for edge in line_x]
    line_x_adv = [edge - reference for edge in line_x_adv]

    fontsize = 53
    fig, ax1 = plt.subplots(figsize=(10, 6))

    ax1.hist((x_esti - x_esti.max()).cpu(), label='ID samples', alpha=0.3, bins=100)
    ax1.hist((tt - x_esti.max()).cpu(), label='adv. samples', alpha=0.3, bins=100)
    ax1.set_yticks([])
    ax1.tick_params(labelsize=fontsize)

    # Second y axis carries the confidence curves.
    ax2 = ax1.twinx()
    ax2.plot(line_x, line_y, linewidth=3, alpha=0.5, color='dodgerblue')
    ax2.plot(line_x_adv, line_y_adv, linewidth=3, alpha=0.5, color='chocolate')
    ax2.tick_params(axis='y', labelsize=fontsize)
    plt.xticks(fontsize=fontsize)
    plt.tight_layout()
    plt.savefig('ad_sample/hist.png')
    plt.clf()

    # ---- centroid scatter -------------------------------------------------
    # The step-0 latents are already in memory; only later steps are loaded.
    steps = (0, 1, 2, 4, 8, 16, 32)
    latents_by_step = [
        (
            f'step_{step}',
            x_latent if step == 0
            else load_data(f'ad_sample/pgd_0.03_{step}_latent.pkl'),
        )
        for step in steps
    ]

    # Class centroids are computed from the step-0 latents.
    centroids = [x_latent[lbls == lbl].mean(axis=0) for lbl in range(10)]

    with open('ad_sample/pgd_0.03_32.pkl', 'rb') as f:
        data = pickle.load(f)['data']

    # Model predictions on the step-32 samples, in batches of 100.
    with torch.no_grad():
        adv_lbls = torch.cat(
            [model(batch.cuda()).argmax(dim=1) for batch in torch.split(data, 100)]
        )

    dfs = []
    for start_label in range(10):
        for end_label in range(10):
            if start_label == end_label:
                continue
            select_index = (lbls == start_label) * (adv_lbls == end_label)
            if select_index.sum() == 0:
                print(f"start_label = {start_label}, end_label = {end_label} is empty")
                continue

            print(f"start_label = {start_label}, end_label = {end_label} is {select_index.sum()}")
            # Row 0: mean distance to the source centroid.
            # Row 1: mean distance to the predicted-class centroid.
            df = pd.DataFrame()
            for key, latents in latents_by_step:
                original_dist = get_dist(latents, centroids[start_label], select_index)
                adv_dist = get_dist(latents, centroids[end_label], select_index)
                df[key] = [
                    original_dist.mean().cpu().numpy(),
                    adv_dist.mean().cpu().numpy()]

            centers_interval = torch.dist(centroids[start_label], centroids[end_label])
            df['centers_interval'] = centers_interval.cpu().numpy()
            dfs.append(df)

    if not dfs:
        # np.stack cannot handle an empty list; nothing to draw in that case.
        print("no (start_label, end_label) pair has samples; skipping scatter plot")
        return

    scatter_x, scatter_y = [], []
    for df in dfs:
        # Normalise so that the two centroids are exactly one unit apart.
        df = df / df['centers_interval'][0]

        advcenter_to_sample = np.stack(df.loc[1][:-1].to_numpy())
        oricenter_to_sample = np.stack(df.loc[0][:-1].to_numpy())
        center_to_center = 1

        # Law of cosines: angle at the source centroid between the line to
        # the other centroid and the line to the sample.
        angle = np.arccos(
            (oricenter_to_sample**2 + center_to_center**2 - advcenter_to_sample**2)
            / (2 * oricenter_to_sample * center_to_center))
        scatter_x.append(oricenter_to_sample * np.cos(angle) - 0.5)
        scatter_y.append(oricenter_to_sample * np.sin(angle))

    scatter_x = np.stack(scatter_x)
    scatter_y = np.stack(scatter_y)

    cmap = plt.cm.Blues
    norm = plt.Normalize(vmin=0, vmax=12)

    fig, ax = plt.subplots(figsize=(10, 6))
    # Source centroid on the left, predicted-class centroid on the right.
    ax.scatter(-0.5, 0, color='orange', marker='p', s=250)
    ax.scatter(0.5, 0, color='orangered', marker='p', s=250)
    for i in range(scatter_x.shape[0]):
        # One colour per PGD step (7 columns), taken from values 4..10 of the
        # 0..12 colour scale.
        ax.scatter(scatter_x[i, :], scatter_y[i, :], s=50, c=cmap(norm(range(4, 11))))

    ax.axes.xaxis.set_visible(False)
    ax.axes.yaxis.set_visible(False)

    plt.tight_layout()
    plt.savefig('ad_sample/scatter.png')
    plt.clf()


# Example invocation:
# python adv_analysis.py --wandb_entity eavnjeong --arch resnet34 --bsz 100 --bsz_vl 100 --exp_load eph/cifar10_resnet34_lin_4 --head lin --dataset cifar10 --method evaluate
if __name__ == '__main__':
    cli_args = get_general_args()
    init_wandb(cli_args)
    # Wrap the experiment in an Evaluator and hand it straight to main().
    main(Evaluator(MLBase(cli_args)))